-
Notifications
You must be signed in to change notification settings - Fork 528
add empirical sinkhorn and sinkhorn divergence functions #80
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Hello @kilianFatras and thank you for the PR, Could you please build the html documentation (folder /docs, execute make html) and check that all the doc for the new function compile OK? |
Hello, thank you for your answer. I updated ot.bregman file with a new doc and I also updated ot.stochastic with a new doc. |
ot/bregman.py
Outdated
.. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative Models with Sinkhorn Divergences, Proceedings of the Twenty-First International Conference on Artficial Intelligence and Statistics, (AISTATS) 21, 2018 | ||
''' | ||
|
||
sinkhorn_div = (2 * empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) - |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function fails if log=True
. as a matter of fact if log=true you should return a log containing all 3 logs + the loss for each of the 3 OT compitation.
We also need a test for log=True
test/test_bregman.py
Outdated
M_t = ot.dist(X_t, X_t) | ||
|
||
emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1) | ||
sinkhorn_div = (2 * ot.sinkhorn2(a, b, M, 1) - ot.sinkhorn2(a, a, M_s, 1) - |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Definitely add a test for log=True here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added 2 log tests, I also changed the sinkhorn div = ot.sinkhorn2(a, b, M, 1) - 1/2 * ot.sinkhorn2(a, a, M_s, 1) - 1/2 * ot.sinkhorn2(b, b, M_t, 1)
It seems that what |
Hello,
I am sending a PR. The PR has for purpose to add empirical functions. Mainly, the added functions just need the original source data, target data and the regularization parameter for entropic OT. In the PR, you will find:
Those functions will be in the bregman.py file. Their test functions will be in test_bregman.py. Finally, their examples have been put in plot_OT_2D_samples.py.